import os
import numpy as np
from omegaconf import OmegaConf
from model import zoo
import torch
import pytorch_lightning as pl
from skimage.metrics import structural_similarity as ssim, peak_signal_noise_ratio as psnr
from tqdm import tqdm
import cv2
from data.loader import load_dataset
from data.utils import perspective_camera_collate_fn
from model.raymarcher_lebesgue import LebesgueRaymarcher, ImplicitRendererDict
from model.raymarcher_simple import EmissionAbsorptionDictRaymarcher
from pytorch3d.renderer import NDCGridRaysampler


def _load_nerf(conf):
    '''
    Build the model class named by ``conf.model_name`` and load trained weights into it.

    :param conf: the OmegaConf object; reads ``trained_ckpt``, ``model_name`` and ``model``
    :return: the model moved to CUDA and switched to eval mode
    '''
    state_dict = torch.load(str(conf.trained_ckpt), map_location='cuda')['state_dict']
    # Drop the first 5 characters of every key. Presumably this strips the attribute
    # prefix under which the training module stored the model (e.g. 'nerf.') —
    # TODO confirm against the trainer before changing.
    state_dict = {key[5:]: value for key, value in state_dict.items()}

    NerfModel = getattr(zoo, conf.model_name)
    nerf = NerfModel(conf.model).cuda().eval()
    nerf.load_state_dict(state_dict)
    return nerf


def _build_renderer(conf):
    '''
    Create the full-image renderer used for evaluation.

    The raymarcher is ``LebesgueRaymarcher`` when ``conf.trainer_name`` is
    ``'trainer_lebesgue'`` and ``EmissionAbsorptionDictRaymarcher`` otherwise.

    :param conf: the OmegaConf object with ``data_conf``, ``data``, ``model`` and ``trainer_name``
    :return: an ``ImplicitRendererDict`` on CUDA
    '''
    # Image size and depth bounds come from conf.data_conf; the samples per ray
    # come from conf.data.
    raysampler_grid = NDCGridRaysampler(
        image_height=conf.data_conf.render_height,
        image_width=conf.data_conf.render_width,
        n_pts_per_ray=conf.data.n_pts_per_ray,
        min_depth=conf.data_conf.min_depth,
        max_depth=conf.data_conf.max_depth,
    ).cuda()

    if conf.trainer_name == 'trainer_lebesgue':
        raymarcher = LebesgueRaymarcher().cuda()
    else:
        raymarcher = EmissionAbsorptionDictRaymarcher().cuda()

    return ImplicitRendererDict(
        raysampler=raysampler_grid, raymarcher=raymarcher,
        # Keyword spelling presumably matches ImplicitRendererDict's signature;
        # do not "fix" it here without changing the class as well.
        stratified_resamling=conf.data.get('stratified_sampling', True),
        eps=conf.model.get('t_eps', 0.0),
    ).cuda()


def _to_bgr_uint8(image):
    '''
    Convert an RGB float image (expected in [0, 1]) to a BGR ``uint8`` array for ``cv2.imwrite``.

    Values are scaled by 255, rounded to the nearest integer and clipped to [0, 255].
    OpenCV only writes 8/16-bit PNGs, so the conversion is done explicitly
    instead of relying on ``imwrite``'s implicit fallback.

    :param image: array whose last axis holds at least 3 channels in RGB order
    :return: ``uint8`` array with the first three channels reordered to BGR
    '''
    scaled = 255.0 * np.asarray(image)[..., [2, 1, 0]]
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)


def test_conf(conf):
    '''
    Run the test loop for the model.

    Every test camera is rendered. PSNR and SSIM are computed against the
    ground-truth images, and both the predictions and the targets are written
    as PNGs under ``results/<exp_name>/lego/``.

    :param conf: the OmegaConf object that contains all the parameters
    '''
    # Default experiment name: the last two components of the config path joined
    # with '.', minus the file extension (e.g. 'a/b/c.yaml' -> 'b.c').
    # An explicit conf.exp_name overrides it.
    exp_name = '.'.join(conf.conf_path.split('/')[-2:])
    exp_name = '.'.join(exp_name.split('.')[:-1])
    exp_name = conf.get('exp_name', exp_name)

    nerf = _load_nerf(conf)
    dataset = load_dataset(conf.data_conf, split='test')
    renderer_grid = _build_renderer(conf)

    # NOTE(review): the scene sub-folder is hard-coded to 'lego' regardless of the data config.
    dir_name = f'results/{exp_name}/lego'
    os.makedirs(f'{dir_name}/pred/', exist_ok=True)
    os.makedirs(f'{dir_name}/target/', exist_ok=True)
    psnrs = []
    ssims = []

    for cam_i in tqdm(range(len(dataset.target_cameras))):
        # One camera per step: slice keeps the cameras container type.
        part = dataset.target_cameras[cam_i:cam_i + 1]
        target = dataset.target_images[cam_i].cpu().numpy()
        cameras = perspective_camera_collate_fn(
            [{'target_camera': part[i], 'target_image': None} for i in range(len(part))]
        )['cameras'].cuda()

        with torch.no_grad():
            outputs = renderer_grid(
                cameras=cameras,
                volumetric_function=nerf.batched_forward,
            )
        # Keep the first 3 channels and drop the leading batch axis.
        # Assumes outputs[0] is the rendered image batch — TODO confirm.
        frame = outputs[0][..., :3].cpu().numpy()[0]

        # Both metrics use an explicit [0, 1] data range. Otherwise skimage infers
        # the PSNR range from the first argument's dtype/values, which can raise
        # or silently use a range of 2.
        ssims.append(ssim(frame, target, channel_axis=-1, data_range=1.0))
        psnrs.append(psnr(target, frame, data_range=1.0))
        cv2.imwrite(f'{dir_name}/pred/{cam_i}.png', _to_bgr_uint8(frame))
        cv2.imwrite(f'{dir_name}/target/{cam_i}.png', _to_bgr_uint8(target))

    print(f'PSNR: {np.mean(psnrs)}, SSIM: {np.mean(ssims)}')
    print(f'Predicted images stored in: {dir_name}/pred/, gt images in: {dir_name}/target/')
    

if __name__ == '__main__':
    # key=value overrides from the command line; must include conf_path.
    cli_overrides = OmegaConf.from_cli()
    # The experiment config first, with CLI values layered on top of it ...
    merged_conf = OmegaConf.merge(OmegaConf.load(cli_overrides.conf_path), cli_overrides)
    # ... then the data config on top of both (it wins on conflicting keys).
    data_conf_path = merged_conf.get('data_conf_path', 'data_configs/cow.yaml')
    merged_conf = OmegaConf.merge(merged_conf, OmegaConf.load(data_conf_path))
    test_conf(merged_conf)